# ------------------------------------------------------------------------------
# Calculates life expectancy based on comorbidities
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=10000]" bash 10_calc_life_expectancy.sh
# ------------------------------------------------------------------------------

# Setup ------------------------------------------------------------------------
library(yaml) # read_yaml absolute filepaths
library(data.table)
library(here) # here() relative filepaths
library(testit) # assert function
library(tidyverse)
library(glue)
library(matrixStats)

# Load the shared helpers in lib/util.R as a module object.
# NOTE(review): `u` is not referenced anywhere in the visible script -- confirm
# whether it is still needed.
u <- modules::use(here("lib", "util.R"))
# Directory from which the prepared comorbidities file is read below.
temp <- here("code", "02_prep_and_summarize_cohort", "temp")

# Helper Functions -------------------------------------------------------------
# Age hazard multiplier for a given sex and age offset.
#
# Raises the sex-specific per-year hazard ratio to the power `age_centered`.
# Reads the script-level constants `age_hr_women` and `age_hr_men`, which must
# be defined before this function is first called.
#
# Args:
#   sex_female:   1 for women, 0 for men (scalar or vector).
#   age_centered: age offset in years, used as the exponent.
# Returns:
#   Numeric multiplier(s); NA where `sex_female` is neither 0 nor 1.
get_age_weight <- function(sex_female, age_centered) {
  # Return the value directly so the result is visible, instead of assigning it
  # to a throwaway local as the last expression.
  case_when(
    sex_female == 1 ~ age_hr_women^age_centered,
    sex_female == 0 ~ age_hr_men^age_centered
  )
}

# Median remaining survival (in months) for a single patient.
#
# Scales the sex-specific baseline hazards in `survival` by the patient's
# `comorbidity_hazard`, converts them to cumulative survival, and returns the
# first month in which cumulative survival drops below 50%. If that does not
# happen within the table's horizon, the function recurses on a table shifted
# forward by that horizon, with the hazard aged by 10 years.
#
# Args:
#   sex_female:         scalar, 1 for women, 0 for men.
#   comorbidity_hazard: scalar hazard multiplier (> 0) for this patient.
#   survival:           table with columns months, p_female, p_male,
#                       baseline_hazard_female, baseline_hazard_male.
#   cum_survival:       survival probability carried in from earlier periods
#                       (1 on the initial call).
#   ...:                ignored; lets callers pass extra named values.
# Returns:
#   The month at which cumulative survival first falls below 0.5, or NA_real_
#   when the inputs are missing/invalid and no estimate can be made.
#
# Relies on the script-level `survival_df` and on `get_age_weight()`.
get_life_expectancy <- function(sex_female, comorbidity_hazard, survival,
                                cum_survival = 1, ...) {
  # Missing or unexpected inputs would make every hazard NA, so no month would
  # ever fall below the threshold and the recursion below would never stop.
  if (!sex_female %in% c(0, 1) || is.na(comorbidity_hazard)) {
    return(NA_real_)
  }
  # A non-positive multiplier can never push survival below 50% either.
  if (comorbidity_hazard <= 0) {
    stop("comorbidity_hazard must be > 0", call. = FALSE)
  }

  # The inputs are scalars, so pick the baseline column with a plain if/else.
  # This also avoids evaluating them inside a data mask, where a same-named
  # column in `survival` would shadow them.
  baseline_hazard <- if (sex_female == 1) {
    survival$baseline_hazard_female
  } else {
    survival$baseline_hazard_male
  }
  survival_period <- 1 - baseline_hazard * comorbidity_hazard
  survival_cum <- cumprod(survival_period) * cum_survival

  est_dead_months <- survival$months[which(survival_cum < 0.5)]
  if (length(est_dead_months) > 0) {
    # Median survival is reached within this table's horizon.
    return(min(est_dead_months))
  }
  if (anyNA(survival_cum)) {
    # Gaps in the table: the median cannot be located, and recursing would
    # carry NA forward indefinitely.
    return(NA_real_)
  }

  # Recursive step: continue the process for a patient one horizon older.
  surv_rate_women <- min(survival$p_female)
  surv_rate_men <- min(survival$p_male)
  add_months <- max(survival$months)

  # Deliberately rebuilt from the original table `survival_df`, not from
  # `survival`: `survival` is already rescaled on recursive calls, and
  # surv_rate_* / add_months carry the accumulated scaling and month offset.
  survival_recur <- survival_df %>%
    mutate(
      months = months + add_months,
      p_female = p_female * surv_rate_women,
      p_male = p_male * surv_rate_men
    ) %>%
    mutate(
      p_lag_female = lag(p_female) %>% replace_na(surv_rate_women),
      p_lag_male = lag(p_male) %>% replace_na(surv_rate_men),
      baseline_hazard_female = 1 - p_female / p_lag_female,
      baseline_hazard_male = 1 - p_male / p_lag_male
    )

  # Age the hazard by 10 years for the next stretch of the table.
  hazard_recur <- comorbidity_hazard * get_age_weight(sex_female, 10)

  get_life_expectancy(
    sex_female, hazard_recur, survival_recur,
    cum_survival = min(survival_cum)
  )
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
# Project file locations, keyed by name.
paths <- here("lib", "filepaths.yml") %>% read_yaml()
overnight_lab <- ""
# Comorbidity table written to the temp directory; age is re-expressed as
# years relative to 66.
load_fp <- file.path(temp, "comorbidities_df.rds")
comorbidities <- mutate(readRDS(load_fp), age_centered = age_at_admit - 66)
# Raw survival table (columns used later: months, p_female, p_male).
survival_raw <- here("lib", "survival_estimates.csv") %>% read_csv()

# Get Multiplier ---------------------------------------------------------------
message("Calculating multiplier...")
# Sex-specific hazard multipliers per year of age; raised to the power
# `age_centered` (age at admission minus 66) by get_age_weight().
# NOTE(review): the source of these two values is not recorded here -- confirm.
age_hr_women <- 1.107
age_hr_men <- 1.094

# Each patient's partial hazard is the row-wise product of every column whose
# name contains "weight_".
wt_vars <- grep("weight_", names(comorbidities), value = TRUE)
partial_hazards <- comorbidities %>%
  select(all_of(wt_vars)) %>%
  as.matrix() %>%
  rowProds()
assert("All hazards > 0", all(partial_hazards > 0))

# Use the shared helper for the age multiplier so this step and the recursive
# step in get_life_expectancy() cannot drift apart.
comorbidities <- comorbidities %>%
  mutate(
    age_multiplier = get_age_weight(sex_female, age_centered),
    partial_hazard = partial_hazards,
    comorbidity_hazard = partial_hazard * age_multiplier
  )

# Prep survival ----------------------------------------------------------------
message("Prepping survival table for use...")
# Turn cumulative survival probabilities into per-row hazards:
#   hazard = 1 - p / (p in the previous row)
# The first row has no predecessor, so its previous survival is taken as 1.
# NOTE(review): assumes rows are already ordered by month -- confirm.
survival_df <- survival_raw %>%
  mutate(
    p_lag_female = replace_na(lag(p_female), 1),
    p_lag_male = replace_na(lag(p_male), 1)
  ) %>%
  mutate(
    baseline_hazard_female = 1 - p_female / p_lag_female,
    baseline_hazard_male = 1 - p_male / p_lag_male
  )

# Life Expectancy --------------------------------------------------------------
message("Getting life expectancy...")
# To get life expectancy:
# (We will use median survival as life expectancy)
# - multiply baseline_hazard_male or baseline_hazard_female by the patient's
#   comorbidity hazard
# - then take 1 - hazard in each period to get prob(survive) in that period
# - then take cumprod() of period-level survival to get individual survival
#   estimates by month
# - LE = first month where survival prob drops below 50%

# comorbidities <- comorbidities[1:10,]
# Pass only the two columns the function needs: spreading every column into
# named arguments risks a clash with its `survival` / `cum_survival`
# parameters. map2_dbl() also enforces one numeric value per patient.
life_exp_df <- comorbidities %>%
  mutate(
    life_expectancy = map2_dbl(
      sex_female, comorbidity_hazard, get_life_expectancy,
      survival = survival_df
    ),
    # Months -> years.
    ly_remaining = life_expectancy / 12
  ) %>%
  select(ed_enc_id, ly_remaining)

# Save -------------------------------------------------------------------------
message("Saving...")
# Write the per-encounter table (ed_enc_id, ly_remaining) to the location
# configured under analysis -> life_expectancy in lib/filepaths.yml.
write_rds(life_exp_df, paths$analysis$life_expectancy)

message("Done.")
